{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "def separate_by_class(dataset):\n",
    "    separated = {}\n",
    "    for i in range(len(dataset)):\n",
    "        vector = dataset[i]\n",
    "        if (vector[-1] not in separated):\n",
    "            separated[vector[-1]] = []\n",
    "        separated[vector[-1]].append(vector)\n",
    "    return separated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{0: [[1, 2, 3, 0], [1, 3, 2, 0]], 1: [[0, 3, 2, 1], [0, 2, 3, 1]]}"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = [[1, 2, 3, 0],\n",
    "        [1, 3, 2, 0],\n",
    "        [0, 3, 2, 1],\n",
    "        [0, 2, 3, 1]]\n",
    "\n",
    "testset = [[1, 3, 3, 0],\n",
    "           [0, 3, 3, 1]]\n",
    "\n",
    "separate_by_class(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 91,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "separated {0: [[1, 2, 3, 0], [1, 3, 2, 0]], 1: [[0, 3, 2, 1], [0, 2, 3, 1]]}\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{0: [(1.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)],\n",
       " 1: [(0.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)]}"
      ]
     },
     "execution_count": 91,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def mean(numbers):\n",
    "    return sum(numbers) / float(len(numbers))\n",
    "\n",
    "def stdev(numbers):\n",
    "    avg = mean(numbers)\n",
    "    variance = sum([pow(x - avg, 2) for x in numbers]) / float(len(numbers) - 1)\n",
    "    return np.sqrt(variance)\n",
    "\n",
    "def summary(dataset):\n",
    "    summaries = [(mean(attr), stdev(attr)) for attr in zip(*dataset)]\n",
    "    del summaries[-1]\n",
    "    return summaries\n",
    "\n",
    "def summarize_by_class(dataset):\n",
    "    separated = separate_by_class(dataset)\n",
    "    print(\"separated {}\".format(separated))\n",
    "    summaries = {}\n",
    "    for cls_val, instances in separated.items():\n",
    "        summaries[cls_val] = summary(instances)\n",
    "    return summaries\n",
    "\n",
    "summarize_by_class(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calc_prob(x, mean, stdev):\n",
    "    exp = np.exp(-(pow(x-mean, 2)/(2*pow(stdev, 2))))\n",
    "    return (1/(np.sqrt(2*np.pi)*stdev))*exp\n",
    "\n",
    "def calc_cls_prob(summaries, new_vector):\n",
    "    probs = {}\n",
    "    for cls_val, cls_summaries in summaries.items():\n",
    "        probs[cls_val] = 1\n",
    "        for i in range(len(cls_summaries)):\n",
    "            mean, stdev = cls_summaries[i]\n",
    "            probs[cls_val] *= calc_prob(new_vector[i], mean, stdev)\n",
    "    return probs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict(summaries, new_vector):\n",
    "    probs = calc_cls_prob(summaries, new_vector)\n",
    "    best_label, best_prob = None, -1\n",
    "    for cls_val, prob in probs.items():\n",
    "        if best_label is None or prob > best_prob:\n",
    "            best_label = cls_val\n",
    "            best_prob = prob\n",
    "    return best_label\n",
    "\n",
    "def get_pred(summaries, testset):\n",
    "    preds = []\n",
    "    for i in range(len(testset)):\n",
    "        result = predict(summaries, testset[i])\n",
    "        preds.append(result)\n",
    "    return preds\n",
    "\n",
    "def get_acc(testset, preds):\n",
    "    correct = 0\n",
    "    for x in range(len(testset)):\n",
    "        if testset[x][-1] == preds[x]:\n",
    "            correct += 1\n",
    "    return (correct / float(len(testset))) * 100.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1, 2, 3, 0], [1, 3, 2, 0], [0, 3, 2, 1], [0, 2, 3, 1]]"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{0: [(1.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)], 1: [(0.0, 0.0), (2.5, 0.7071067811865476), (2.5, 0.7071067811865476)]}\n",
      "[0, 0]\n",
      "50.0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:2: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  \n",
      "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: divide by zero encountered in double_scalars\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n",
      "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:2: RuntimeWarning: divide by zero encountered in double_scalars\n",
      "  \n",
      "/home/lusx/anaconda3/envs/torch/lib/python3.6/site-packages/ipykernel_launcher.py:3: RuntimeWarning: invalid value encountered in double_scalars\n",
      "  This is separate from the ipykernel package so we can avoid doing imports until\n"
     ]
    }
   ],
   "source": [
    "summaries = summarize_by_class(data)\n",
    "print(summaries)\n",
    "preds = get_pred(summaries, testset)\n",
    "acc = get_acc(testset, preds)\n",
    "print(preds)\n",
    "print(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:torch]",
   "language": "python",
   "name": "conda-env-torch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
